function ok = A1_cost_estimation(i_k, max_restarts)

% A1_cost_estimation
%==========================================================================

% Description: This code infers the cost parameters of the model:
% - Market-sector-specific fixed cost
% - Producer-market-sector-specific marginal cost
% Each product category should ideally be run on a server.
%
% Inputs:
%   i_k          - Row of index_price_coeff that selects product category k.
%                  Either a positive integer or its text form (e.g. '3'),
%                  so the function can be launched from the command line.
%   max_restarts - (optional) Maximum number of random restarts of fsolve
%                  when a market's fixed-cost system does not return
%                  exitflag 1. Default 1000. Pass Inf to retry until
%                  convergence; that can loop forever if no root exists.
%
% Output:
%   ok           - 1 once all estimates have been written.
%
% Files read:    main_data_supply.mat, companies_k.mat (working directory)
% Files written: cost_estimates/costs_table_<i_k>_k.csv
%                cost_estimates/costs_<i_k>_k.csv
%                cost_estimates/costs_<i_k>_k.mat
% External functions: A1a_mc_f_opt, A1b_mc, A1c_p_N

%==========================================================================

if nargin < 2 || isempty(max_restarts)
    max_restarts = 1000;
end

% Accept a numeric index or its text form. str2double is used instead of
% str2num, which evaluates its argument and rejects numeric input.
if ~isnumeric(i_k)
    i_k = str2double(i_k);
end
if ~isnumeric(i_k) || ~isscalar(i_k) || ~isfinite(i_k) || i_k < 1 || i_k ~= fix(i_k)
    error('A1_cost_estimation:badIndex', ...
        'i_k must be a positive integer or the text form of one.');
end
fprintf('A1_cost_estimation: i_k = %d\n', i_k);


% Import Data and Estimates for product k
%==========================================================================

% Load into a struct so the MAT-file cannot silently overwrite locals.
S = load('main_data_supply.mat');

k = S.index_price_coeff(i_k, 1);
data_k = S.data_summary{k};
numProdsTotal = size(data_k, 1);
X_MPEC_opt = S.X_MPEC_opt_summary{k};
price_k = data_k(:,7) + 1;
distance_k = data_k(:,14);
population_log_k = log(data_k(:,16));
Y_k = S.Y_summary{k};
FE_k = S.FE_summary{k};
market_id_k = data_k(:,end);
% Number of markets is read from the last row, so rows are presumably
% sorted by market id 1..T.
T = market_id_k(end, 1);
nn = size(Y_k, 2);

alpha_constant   = X_MPEC_opt(1,1);
alpha_distance   = X_MPEC_opt(2,1);
alpha_population = X_MPEC_opt(3,1);
alpha_price      = X_MPEC_opt(4,1);
alpha_1          = X_MPEC_opt(12,1);
n_A = 3*(S.K_A + 1);
alpha_FE = X_MPEC_opt((n_A + 1):(n_A + (S.K - S.K_A)), 1);

xi_start = n_A + (S.K - S.K_A) + 2;
xi_k = X_MPEC_opt((xi_start + 1):(xi_start + numProdsTotal), 1) + 10;     % Ensure positive values
xi_k = xi_k - alpha_constant - log(price_k)*alpha_price - alpha_population*population_log_k - alpha_distance*distance_k;

% Contiguous row range of every market
[prodsMarket, marketStarts, marketEnds, marketForProducts] = ...
    local_market_layout(data_k(:,17), market_id_k, T, numProdsTotal);

% Compute Market Shares (full sample, used in sections 2 and 3)
%--------------------------------------------------------------------------

delta_k = alpha_constant + xi_k + log(price_k)*alpha_price + alpha_population*population_log_k + alpha_distance*distance_k;
[simShare_true, EstShare_true] = local_logit_shares(delta_k, price_k, FE_k, Y_k, ...
    alpha_1, alpha_FE, marketStarts, marketEnds, marketForProducts);

%==========================================================================


% 1. Estimate Fixed Cost
%==========================================================================

% Import data on number of firms
%--------------------------------------------------------------------------

% Load only the variable that is used, so the MAT-file cannot overwrite
% other locals (e.g. k) in this workspace.
C = load('companies_k.mat', 'company_summary');
companies_k = C.company_summary{k};

% Column 4 of companies_k supplies N. Values below one are floored at one.
% This is loop-invariant, so it is computed once.
N_rows = table2array(companies_k(:, 4));
N_rows(N_rows < 1) = 1;

fsolve_options = optimoptions('fsolve', 'Display', 'none');
expenditure_k_market = 1;
f_summary = zeros(T, 1);
flag_summary = zeros(T, 1);
N_0_summary = zeros(T, 1);
f_fval = [];

% Infer fixed cost for each market i_f
%--------------------------------------------------------------------------

for i_f = 1:T
    index_market = market_id_k == i_f;

    N_market = N_rows(index_market);
    N = N_market(1);

    % Rows whose declarant (column 3) equals the partner (column 4)
    data_market = data_k(index_market, :);
    pos_dom = find(data_market(:, 3) == data_market(:, 4));
    if isempty(pos_dom)
        % No such row: use the last row of the market and the value in
        % column 5 of the first companies_k row. isempty (rather than a
        % row count) also catches the 1-by-0 result that find returns
        % for single-row markets.
        pos_dom = size(data_market, 1);
        N = table2array(companies_k(1, 5));
    end
    N_0 = N;

    prodsMarket_market = prodsMarket(i_f);
    price_k_market = price_k(index_market, :);

    % Simulated logit shares of this market on its own
    [simShare_market, EstShare_market] = local_logit_shares( ...
        delta_k(index_market), price_k_market, FE_k(index_market, :), ...
        Y_k(i_f, :), alpha_1, alpha_FE, ...
        1, size(data_market, 1), ones(prodsMarket_market, 1));
    Y_market = repmat(Y_k(i_f, :), prodsMarket_market, 1);

    % Unknowns are [mc of every product in the market; f]. Start values
    % are mc = 0.9*price and f = 0.1.
    mc_f_opt_handle = @(mc_f) A1a_mc_f_opt(mc_f, EstShare_market, price_k_market, expenditure_k_market, alpha_price, alpha_1, Y_market, simShare_market, nn, N, pos_dom);
    mc_f_0 = [0.9*price_k_market; 0.1];
    % Option 2 (Artelys Knitro): replace the fsolve calls inside
    % local_fsolve_with_restarts by knitro_nlneqs(fun, x0).
    [mc_f_opt, fval, flag] = local_fsolve_with_restarts(mc_f_opt_handle, mc_f_0, ...
        price_k_market, fsolve_options, max_restarts);

    f_summary(i_f) = mc_f_opt(end);
    flag_summary(i_f) = flag;
    N_0_summary(i_f) = N_0;
    f_fval = [f_fval; fval(1:(end-1))]; %#ok<AGROW>
end

f_flag = flag_summary(marketForProducts);

%==========================================================================


% 2. Estimate Marginal Cost mc_opt
%==========================================================================

% Solve First-Order Conditions (one scalar equation per product)
%--------------------------------------------------------------------------

expenditure_k_market = 1;
mc_options = optimoptions('fsolve', 'TolFun', 1e-15, 'MaxFunEvals', 1000, 'Display', 'none');
% Columns: [mc estimate, residual, fsolve exitflag]. Rows whose residual
% at the start value is NaN are left at zero.
mc_opt_summary = zeros(numProdsTotal, 3);
for i = 1:numProdsTotal
    t = market_id_k(i);
    f_opt = f_summary(t);
    mc = ones(prodsMarket(t), 1);
    mc_opt_0 = 0.9 * price_k(i);
    mc_opt_handle = @(mc_opt) A1b_mc(mc_opt, mc, i, market_id_k, Y_k, prodsMarket, EstShare_true, marketStarts, marketEnds, price_k, expenditure_k_market, f_opt, alpha_price, alpha_1, simShare_true, nn);
    fval_0 = mc_opt_handle(mc_opt_0);
    if ~isnan(fval_0)
        [mc_opt_i, fval, flag] = fsolve(mc_opt_handle, mc_opt_0, mc_options);
        mc_opt_summary(i, 1) = mc_opt_i;
        mc_opt_summary(i, 2) = fval;
        mc_opt_summary(i, 3) = flag;
    end
end

mc_opt = mc_opt_summary(:,1);
f_vec = f_summary(marketForProducts);
N_true = (price_k - mc_opt) ./ price_k .* EstShare_true ./ f_vec;          % Actual nr of firms
q_k = alpha_constant + xi_k + FE_k*alpha_FE - log(N_true);                 % Actual quality per firm

%==========================================================================


% 3. Derive Cost Parameters and Counterfactual prices, q's, and mc's
%==========================================================================

% Determine m1, new mc, and new quality
%--------------------------------------------------------------------------

m1 = zeros(numProdsTotal, 1);
m0 = zeros(numProdsTotal, 1);
q_new = zeros(numProdsTotal, 1);
mc_new = zeros(numProdsTotal, 1);
for t = 1:T
    rows = marketStarts(t):marketEnds(t);
    s_t = simShare_true(rows, :);
    N_j = EstShare_true(rows) .* (price_k(rows) - mc_opt(rows)) ./ price_k(rows) .* expenditure_k_market ./ f_vec(rows);
    alpha_i_t = alpha_price + alpha_1*repmat(Y_k(t, :), prodsMarket(t), 1);
    alpha_mean_t = mean(alpha_i_t, 2);
    numer_t = sum(s_t, 2)./N_j - sum(s_t.^2, 2)./(N_j.^2);
    denom_t = sum(alpha_i_t.*s_t, 2)./N_j - sum(alpha_i_t.*(s_t.^2), 2)./(N_j.^2);
    log_mc = log(mc_opt(rows));
    m1(rows) = -numer_t ./ denom_t .* q_k(rows) ./ log_mc;
    m0(rows) = log_mc ./ (q_k(rows) .^ m1(rows));
    q_new(rows) = (-1 ./ alpha_mean_t ./ m0(rows) ./ m1(rows)) .^ (1 ./ (m1(rows) - 1));
    mc_new(rows) = exp(m0(rows) .* q_new(rows) .^ m1(rows));
end

% Only rows with declarant == partner keep these estimates
is_foreign = data_k(:,3) ~= data_k(:,4);
q_new(is_foreign) = 0;
mc_opt(is_foreign) = 0;
mc_new(is_foreign) = 0;
m1(is_foreign) = 0;
alpha_i = alpha_price + alpha_1 * Y_k(marketForProducts,:);
alpha_mean = mean(alpha_i, 2);

% Determine new price and number of firms
%--------------------------------------------------------------------------

price_new = zeros(numProdsTotal, 1);
N_new = zeros(numProdsTotal, 1);
p_flag = zeros(numProdsTotal, 1);
p_fval = zeros(numProdsTotal, 2);
p_options = optimoptions('fsolve', 'TolFun', 1e-7, 'MaxFunEvals', 1000, 'Display', 'none');
N_0_vec = N_0_summary(marketForProducts);
for i = 1:numProdsTotal
    % Skip non-domestic rows and rows with NaN or very large (>= 1e6) mc_new
    if ~is_foreign(i) && ~isnan(mc_new(i)) && mc_new(i) < 1000000
        f_opt = f_vec(i);
        A1c_p_N_handle = @(price_N) A1c_p_N(price_N, i, mc_new, f_opt, alpha_mean);
        [price_N_opt, fval, flag] = fsolve(A1c_p_N_handle, [price_k(i); N_0_vec(i)], p_options);
        price_new(i) = price_N_opt(1);
        N_new(i) = price_N_opt(2);
        p_flag(i) = flag;
        p_fval(i,:) = fval';
    end
end

%==========================================================================


% 4. Save Estimates to file
%==========================================================================

if ~exist('cost_estimates', 'dir')
    mkdir('cost_estimates');
end

% SAVE AS CSV-FILE (TABLE)
filename_costs_csv = sprintf('cost_estimates/costs_table_%i_k.csv', i_k);
data_save = array2table([data_k(:,1:4), m1, f_vec, q_k, q_new, mc_opt, mc_new, price_k, price_new, N_true, N_new, f_flag, f_fval, mc_opt_summary(:,3), mc_opt_summary(:,2), p_flag, p_fval]);
data_save.Properties.VariableNames = {'Year' 'Quarter' 'Declarant' 'Partner', 'm1', 'f_opt', 'q_k', 'q_new', 'mc_opt', 'mc_new', 'price_k', 'price_new', 'N_true', 'N_new', 'f_flag', 'f_fval', 'mc_flag', 'mc_fval', 'p_flag', 'p_fval1', 'p_fval2'};
writetable(data_save, filename_costs_csv, 'Delimiter',',','QuoteStrings',true);

% SAVE AS CSV-FILE (NUMERIC)
% csvwrite keeps at most five significant digits; dlmwrite with an
% explicit precision writes full double precision.
filename_costs_csv = sprintf('cost_estimates/costs_%i_k.csv', i_k);
dlmwrite(filename_costs_csv, [data_k(:,1:4), m1, f_vec, q_k, q_new, mc_opt, mc_new, price_k, price_new, N_true, N_new], ...
    'delimiter', ',', 'precision', '%.15g');

% SAVE AS MAT-FILE
filename_costs_matlab = sprintf('cost_estimates/costs_%i_k.mat', i_k);
save(filename_costs_matlab, 'data_k', 'm1', 'f_vec', 'q_k', 'q_new', 'mc_opt', 'mc_new', 'price_k', 'price_new', 'N_true', 'N_new', 'f_flag', 'f_fval', 'mc_opt_summary', 'p_flag', 'p_fval');

ok = 1;

end


function [prodsMarket, marketStarts, marketEnds, marketForProducts] = local_market_layout(prodsMarket_col, market_id, T, numProds)
% Per-market product counts and the contiguous row ranges they imply.
% The count of market t is read from its first row in prodsMarket_col.
% Markets are laid out back to back starting at row 1, and
% marketForProducts maps each row to its market.
prodsMarket = zeros(T, 1);
for t = 1:T
    counts_t = prodsMarket_col(market_id == t);
    prodsMarket(t) = counts_t(1);
end
marketEnds = cumsum(prodsMarket);
marketStarts = [1; marketEnds(1:end-1) + 1];
marketForProducts = zeros(numProds, 1);
for t = 1:T
    marketForProducts(marketStarts(t):marketEnds(t)) = t;
end
end


function [simShare, EstShare] = local_logit_shares(delta, price, FE, Y, alpha_1, alpha_FE, marketStarts, marketEnds, marketForProducts)
% Simulated random-coefficient logit shares.
% Row t of Y holds the nn simulation draws of market t. The individual
% utility term is log(price)*alpha_1*Y(t,:) + FE*alpha_FE. simShare is the
% per-draw share (rows x nn), and EstShare is its mean over the draws.
% Per-market sums use a sparse indicator matrix, as in the original code.
T = numel(marketStarts);
nn = size(Y, 2);
numProds = size(price, 1);
expmu = zeros(numProds, nn);
sharesum = sparse(T, numProds);
for t = 1:T
    rows = marketStarts(t):marketEnds(t);
    expmu(rows, :) = exp(log(price(rows, :))*alpha_1*Y(t, :) + repmat(FE(rows, :)*alpha_FE, 1, nn));
    sharesum(t, rows) = 1;
end
numer = (exp(delta)*ones(1, nn)) .* expmu;
inv_sum = 1 ./ (sharesum*numer);
simShare = numer .* inv_sum(marketForProducts, :);
EstShare = mean(simShare, 2);
end


function [x_opt, fval, flag] = local_fsolve_with_restarts(fun, x0, price_market, options, max_restarts)
% Solve fun(x) = 0 with fsolve starting from x0.
% While the exitflag is not 1, restart from mc = rand*price_market and
% f = rand (one rand draw each, mc first), up to max_restarts times.
% Returns the last attempt and warns if it still did not reach exitflag 1.
% Accepting exitflags 2 or 3 as well would be a possible relaxation.
[x_opt, fval, flag] = fsolve(fun, x0, options);
n_restarts = 0;
while flag ~= 1 && n_restarts < max_restarts
    n_restarts = n_restarts + 1;
    mc_0 = rand(1, 1) .* price_market;
    f_0 = rand(1, 1);
    [x_opt, fval, flag] = fsolve(fun, [mc_0; f_0], options);
end
if flag ~= 1
    warning('A1_cost_estimation:noConvergence', ...
        'fsolve ended with exitflag %d after %d random restarts.', flag, n_restarts);
end
end
